{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "import tqdm\n",
    "from PIL import Image\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torchvision import transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "#device = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyDataset(Dataset):\n",
    "    \"\"\"Image dataset that repeats a short list of files ``k`` times.\n",
    "\n",
    "    Each item is opened from disk, cropped to 30x30, converted to a\n",
    "    single-channel tensor and returned together with its label row.\n",
    "\n",
    "    Args:\n",
    "        files: list of image paths; repeated with ``files * k``.\n",
    "        labels: list of label rows parallel to ``files``; repeated with\n",
    "            ``labels * k`` and stored as a float32 tensor.\n",
    "        mode: ``'test'`` selects a center crop; any other value selects the\n",
    "            random resized crop.\n",
    "        k: integer repeat factor applied to both lists.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, files, labels, mode, k):\n",
    "        super().__init__()\n",
    "        self.files = list(files * k)\n",
    "        self.labels = torch.tensor(labels * k, dtype=torch.float32)\n",
    "        self.mode = mode\n",
    "        self.len_ = len(self.files)\n",
    "\n",
    "        # Build the pipelines once here instead of on every __getitem__ call;\n",
    "        # the random crop still draws fresh parameters each time it is applied.\n",
    "        # NOTE(review): scale is the crop area as a fraction of the source image,\n",
    "        # so the 1.1 upper bound requests crops larger than the image -- confirm intent.\n",
    "        self._train_transforms = transforms.Compose([\n",
    "            transforms.RandomResizedCrop((30, 30), scale=(0.8, 1.1)),\n",
    "        ])\n",
    "        self._test_transforms = transforms.Compose([\n",
    "            transforms.CenterCrop((30, 30)),\n",
    "        ])\n",
    "        self._to_tensor = transforms.Compose([\n",
    "            transforms.Grayscale(),\n",
    "            transforms.ToTensor(),\n",
    "        ])\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of samples after repetition.\"\"\"\n",
    "        return self.len_\n",
    "\n",
    "    def load_sample(self, file):\n",
    "        \"\"\"Open ``file`` with PIL and read its pixel data immediately.\"\"\"\n",
    "        image = Image.open(file)\n",
    "        image.load()  # Image.open is lazy; load() forces the read now\n",
    "        return image\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return ``(image_tensor, label_row)`` for position ``index``.\"\"\"\n",
    "        x = self.load_sample(self.files[index])\n",
    "        # mode is read at call time, so reassigning ``dataset.mode`` later still takes effect\n",
    "        crop = self._test_transforms if self.mode == 'test' else self._train_transforms\n",
    "        x = self._to_tensor(crop(x))\n",
    "        y = self.labels[index]\n",
    "        return x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One image path per label index: dataset/<index>/img.png for indices 0..13.\n",
    "files = [f'dataset/{i}/img.png' for i in range(14)]\n",
    "# One-hot label rows: row i holds a 1 in column i and 0 everywhere else.\n",
    "labels = [[1 if j == i else 0 for j in range(14)] for i in range(14)]\n",
    "\n",
    "# Both datasets read the same 14 files; they differ in crop mode and repeat factor.\n",
    "train_dataset = MyDataset(files, labels, 'train', 400)\n",
    "test_dataset = MyDataset(files, labels, 'test', 20)\n",
    "\n",
    "# batch_size equals the dataset length, so each loader yields one full batch.\n",
    "train_loader = DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small CNN for 1x30x30 grayscale crops with 14 outputs.\n",
    "# Spatial size: 30 -> conv 5x5 -> 26 -> max-pool 2 -> 13 -> conv 5x5 -> 9,\n",
    "# so Flatten produces 16 * 9 * 9 = 16 * 81 features.\n",
    "# NOTE(review): the Softmax output is paired with nn.BCELoss in the optimizer cell;\n",
    "# for one-hot targets nn.CrossEntropyLoss on raw logits is the usual choice -- confirm before changing.\n",
    "model = nn.Sequential(\n",
    "    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),\n",
    "    nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),\n",
    "    nn.Flatten(),\n",
    "    nn.Linear(16 * 81, 128),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(128, 80),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(80, 14),\n",
    "    nn.Softmax(dim=1)\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "#model = torch.load('model/v3').to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Binary cross-entropy between the model outputs and the one-hot label rows.\n",
    "criterion = nn.BCELoss()\n",
    "# Adam over every model parameter with a learning rate of 1e-5.\n",
    "optim = torch.optim.Adam(params=model.parameters(), lr=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "losses_tr = []\n",
    "losses_te = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 2/600 [00:00<00:44, 13.37it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0003641020739451051 1.970412085938733e-05\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|▌         | 32/600 [00:02<00:41, 13.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.00017527485033497214 7.941197509353515e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 62/600 [00:04<00:39, 13.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.00013690149353351444 9.545492503093556e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█▌        | 92/600 [00:06<00:39, 12.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.00010721362195909023 1.2570008948387112e-05\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 122/600 [00:09<00:35, 13.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8.08697150205262e-05 1.6964031601673923e-05\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 25%|██▌       | 152/600 [00:11<00:33, 13.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6.518440932268277e-05 1.7977388779399917e-05\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 182/600 [00:13<00:31, 13.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5.5659242207184434e-05 1.5467163393623196e-05\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 35%|███▌      | 212/600 [00:15<00:28, 13.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.8896847147261724e-05 1.2713462638203055e-05\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 242/600 [00:18<00:26, 13.37it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.374516720417887e-05 1.0662631211744156e-05\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 45%|████▌     | 272/600 [00:20<00:23, 13.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3.96366449422203e-05 9.105609024118166e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 302/600 [00:22<00:21, 13.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3.62380851584021e-05 7.927233127702493e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 55%|█████▌    | 332/600 [00:24<00:19, 13.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3.337075759191066e-05 7.001455742283724e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 362/600 [00:26<00:17, 13.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3.0902559956302866e-05 6.26860082775238e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 65%|██████▌   | 392/600 [00:29<00:15, 13.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.8736852982547134e-05 5.676927230524598e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 422/600 [00:31<00:12, 13.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.6819425329449587e-05 5.167404196981806e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 75%|███████▌  | 452/600 [00:33<00:10, 13.82it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.510730701033026e-05 4.7278599595301785e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 482/600 [00:35<00:08, 13.20it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.3569115001009777e-05 4.351290954218712e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 85%|████████▌ | 512/600 [00:37<00:06, 13.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.2176345737534575e-05 4.019746938865865e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 542/600 [00:40<00:04, 13.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0906329154968262e-05 3.7256208997860085e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 95%|█████████▌| 572/600 [00:42<00:02, 13.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.9745028112083673e-05 3.4451866213203175e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 600/600 [00:44<00:00, 13.53it/s]\n"
     ]
    }
   ],
   "source": [
    "for x_train, y_train in train_loader:\n",
    "    x_train = x_train.to(device)\n",
    "    y_train = y_train.to(device)\n",
    "    break\n",
    "for x_test, y_test in test_loader:\n",
    "    x_test = x_test.to(device)\n",
    "    y_test = y_test.to(device)\n",
    "    break\n",
    "for epoch in tqdm.tqdm(range(600)):\n",
    "    optim.zero_grad()\n",
    "    model.train()\n",
    "    #for x, y in train_loader:\n",
    "        #print(x.shape)\n",
    "    #x = x.to(device)\n",
    "    preds = model.forward(x_train)\n",
    "    loss = criterion(preds, y_train)\n",
    "    loss.backward()\n",
    "    optim.step()\n",
    "    model.eval()\n",
    "    torch.no_grad()\n",
    "    preds =  model.forward(x_test)\n",
    "    loss_test = criterion(preds, y_test)\n",
    "    if epoch % 30 == 0:\n",
    "        print(float(loss.detach()), float(loss_test.detach()))\n",
    "    losses_tr.append(float(loss.detach()))\n",
    "    losses_te.append(float(loss_test.detach()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x22ccbabe6a0>"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAOHklEQVR4nO3df4hd9ZnH8fdjtBgTQxJ/ZGOamCpqdhE3LiEsuIpaFZVqFGmoQs1K2RSsULF/rLiC/qWyqZaSP8RxE5quWa0/1wiyVqUgEhGjZI0xbuuGaKcZJwZLtOCP/Hj2j7mR2Tjn3PHOvXOu+b5fMMy957nnnMdjPnPOud977onMRNLh74imG5A0OQy7VAjDLhXCsEuFMOxSIQy7VIgjJzJzRFwK/BKYAvxbZt7T5vWO80k9lpkx1vTodJw9IqYAvwcuBgaB14BrM/PtmnkMu9RjVWGfyGH8UuDdzNyemV8AjwDLJrA8ST00kbDPA/446vlga5qkPjSRc/axDhW+cpgeESuBlRNYj6QumEjYB4H5o55/G9h56IsycwAYAM/ZpSZN5DD+NeC0iPhORHwL+AGwoTttSeq2jvfsmbkvIm4CnmNk6G1tZm7tWmeSuqrjobeOVuZhvNRzvRh6k/QNYtilQhh2qRCGXSqEYZcKYdilQhh2qRCGXSqEYZcKYdilQhh2qRCGXSqEYZcKYdilQhh2qRCGXSqEYZcKYdilQhh2qRCGXSrEhG7sqK/vrLPOqq0vWrSosvb4449X1g4cOFBZu+GGG2rXecopp9TWq0yfPr22vnDhwsra/v37K2s7d37l9gNf2rNnT+06H3jggcra4OBg7byHO/fsUiEMu1QIwy4VwrBLhTDsUiEMu1SICd3rLSJ2AJ8A+4F9mbmkzeuLuNfbEUdU/w295pprauddvXp1Ze3DDz+srNX9f5w7d27tOo8++ujaepUjj6wfua1bbl2/X3zxRWVt7969teu84IILKmubNm2qnfdwUXWvt26Ms1+Qmbu7sBxJPeRhvFSIiYY9gd9GxOsRsbIbDUnqjYkexp+TmTsj4kTg+Yh4JzNfGv2C1h8B/xBIDZvQnj0zd7Z+7wKeApaO8ZqBzFzS7s07Sb3VcdgjYlpEHHvwMXAJ8Fa3GpPUXRM5jJ8DPBURB5fzH5n5X13pSlLXdRz2zNwO/G0Xezls1F3aeeONN9bOO2fOnI5qdePW7T5LUXd57DvvvNPxcjtVNz4/b9682nnrPuNQOreMVAjDLhXCsEuFMOxSIQy7VAjDLhXCb5ftgbpLP2fMmNHxct9+++3K2tatWytr+/btq13uZ599VlmrGyqsm28irrjiisraqlWrerLOErhnlwph2KVCGHapEIZdKoRhlwph2KVCOPTWA3VXbbX7ptc6Dz30UGXt7rvv7ni5TZg6dWpl7corr6ysnXHGGb1opwju2aVCGHapEIZdKoRhlwph2KVCGHapEA699cBEht7qbt74wQcfdNxTE1rfPDymiy66qLJ27rnnVtb27NlTu866m0KWzj27VAjDLhXCsEuFMOxSIQy7VAjDLhXCsEuFaDvOHhFrge8BuzLzzNa02cBvgIXADmB5Zv65d21+swwPD1fW1qxZUzvve++9V1l77rnnOu6pCQsWLKisLV++vLJW91mEu+66q3ad27dvb99YocazZ/8VcOkh024FXszM04AXW88l9bG2Yc/Ml4CPDpm8DFjXerwOuKrLfUnqsk4/LjsnM4cAMnMoIk6semFErARWdrgeSV3S88/GZ+YAMAAQEdnr9UkaW6fvxg9HxFyA1u9d3WtJUi90GvYNwIrW4xXA091pR1KvRGb9kXVEPAycDxwPDAN3AP8JPAosAN4Hvp+Zh76JN9ayijiMr7ux48yZM2vn3bt3b2Wt3eWdk63uvxPguuuuq6ytXr26srZhw4bK2u233167zrqhy1Jk5pjXFrc9
Z8/MaytK351QR5ImlZ+gkwph2KVCGHapEIZdKoRhlwrht8v2wL59+ypru3fvnsROemvp0qW19VWrVlXW6q4MfPTRRytr77//fvvGNCb37FIhDLtUCMMuFcKwS4Uw7FIhDLtUCIfeVGv69OmVtQsvvLB23hNOOKGyNjAwUFl75pln2jemr809u1QIwy4VwrBLhTDsUiEMu1QIwy4VwrBLhXCcXbXfErts2bLK2h133FG73K1bt1bWNm7c2L4xdZV7dqkQhl0qhGGXCmHYpUIYdqkQhl0qxHhu7LgW+B6wKzPPbE27E/gn4MPWy27LzGfbrqyQGzt+00ybNq2y9vLLL1fWZsyYUbvcW265pbL29NPe+LdXqm7sOJ49+6+AS8eY/ovMXNz6aRt0Sc1qG/bMfAloeztmSf1tIufsN0XEmxGxNiJmda0jST3RadjvB04FFgNDwL1VL4yIlRGxKSI2dbguSV3QUdgzczgz92fmAeBBoPI+QJk5kJlLMnNJp01KmriOwh4Rc0c9vRp4qzvtSOqVtle9RcTDwPnA8RExCNwBnB8Ri4EEdgA/7mGP6rFLLrmksrZo0aLK2iuvvFK73GefdZCmn7QNe2ZeO8bkNT3oRVIP+Qk6qRCGXSqEYZcKYdilQhh2qRCGXSpE20tcu7oyL3FtRN23x0L9ZawLFiyorF1//fW1y33hhRfqG1NPTOQSV0mHAcMuFcKwS4Uw7FIhDLtUCMMuFcIbOxbg4osvrq2ffPLJlbXdu3dX1upu3Kj+455dKoRhlwph2KVCGHapEIZdKoRhlwrh0FsB6r49FmDq1KmVtUceeaSyNjQ01HFPmnzu2aVCGHapEIZdKoRhlwph2KVCGHapEOO5seN84NfAXwEHgIHM/GVEzAZ+Ayxk5OaOyzPzz71rVXVmzpxZWVuypP5u2VOmTKmsrV+/vuOe1F/Gs2ffB/wsM/8a+HvgJxHxN8CtwIuZeRrwYuu5pD7VNuyZOZSZb7QefwJsA+YBy4B1rZetA67qVZOSJu5rnbNHxELgbOBVYE5mDsHIHwTgxG43J6l7xv1x2YiYDjwB3JyZH0eM+T30Y823EljZWXuSumVce/aIOIqRoK/PzCdbk4cjYm6rPhfYNda8mTmQmUsys/5dIkk91TbsMbILXwNsy8z7RpU2ACtaj1cAT3e/PUndMp7D+HOAHwJbImJza9ptwD3AoxHxI+B94Pu9aVFSN7QNe2a+DFSdoH+3u+2oU5dddlllre7mjAAbN26srA0PD3fck/qLn6CTCmHYpUIYdqkQhl0qhGGXCmHYpUL47bKHifPOO6+ydtJJJ9XO+9hjj1XW9u/f33FP6i/u2aVCGHapEIZdKoRhlwph2KVCGHapEA69fYMcc8wxlbVZs2ZV1o480v/Ncs8uFcOwS4Uw7FIhDLtUCMMuFcKwS4VwTOYb5NNPP62sDQ4OdjQfwOzZsytr470ZiPqfe3apEIZdKoRhlwph2KVCGHapEIZdKsR47uI6PyJ+FxHbImJrRPy0Nf3OiPhTRGxu/Vze+3YldWo84+z7gJ9l5hsRcSzwekQ836r9IjN/3rv2NFpmVtbuvffeytrMmTNrl3v66adX1uouj/38889rl6v+Mp67uA4BQ63Hn0TENmBerxuT1F1f65w9IhYCZwOvtibdFBFvRsTaiKj+9gRJjRt32CNiOvAEcHNmfgzcD5wKLGZkzz/mcWRErIyITRGxqQv9SurQuMIeEUcxEvT1mfkkQGYOZ+b+zDwAPAgsHWvezBzIzCWZuaRbTUv6+sbzbnwAa4BtmXnfqOlzR73sauCt7rcnqVvG8278OcAPgS0Rsbk17Tbg2ohYDCSwA/hxTzqU1BVRN5zT9ZVFTN7K9KX58+fX1o877rjK2pYtWypr3vSxP2XmmNcl+wk6qRCGXSqEYZcKYdilQhh2qRCGXSqEQ2/SYcahN6lwhl0qhGGXCmHYpUIYdqkQhl0qhGGXCmHYpUIYdqkQhl0q
hGGXCmHYpUIYdqkQhl0qhGGXCmHYpUIYdqkQhl0qhGGXCmHYpUIYdqkQ47mLazftBt4b9fz41rR+YT/1+q0f6L+emu7n5KrCpH6V9FdWHrEpM5c01sAh7Kdev/UD/ddTv/UzmofxUiEMu1SIpsM+0PD6D2U/9fqtH+i/nvqtny81es4uafI0vWeXNEkaCXtEXBoR/xMR70bErU30cEg/OyJiS0RsjohNDfWwNiJ2RcRbo6bNjojnI+IPrd+zGu7nzoj4U2s7bY6Iyyexn/kR8buI2BYRWyPip63pjWyjmn4a20btTPphfERMAX4PXAwMAq8B12bm25PayP/vaQewJDMbGx+NiPOAvwC/zswzW9P+FfgoM+9p/VGclZn/3GA/dwJ/ycyfT0YPh/QzF5ibmW9ExLHA68BVwD/SwDaq6Wc5DW2jdprYsy8F3s3M7Zn5BfAIsKyBPvpKZr4EfHTI5GXAutbjdYz8Y2qyn8Zk5lBmvtF6/AmwDZhHQ9uopp++1UTY5wF/HPV8kOY3UgK/jYjXI2Jlw72MNiczh2DkHxdwYsP9ANwUEW+2DvMn7bRitIhYCJwNvEofbKND+oE+2EZjaSLsY90ovukhgXMy8++Ay4CftA5h9VX3A6cCi4Eh4N7JbiAipgNPADdn5seTvf5x9NP4NqrSRNgHgfmjnn8b2NlAH1/KzJ2t37uApxg51egHw61zw4PniLuabCYzhzNzf2YeAB5kkrdTRBzFSLDWZ+aTrcmNbaOx+ml6G9VpIuyvAadFxHci4lvAD4ANDfQBQERMa73BQkRMAy4B3qqfa9JsAFa0Hq8Anm6wl4NhOuhqJnE7RUQAa4BtmXnfqFIj26iqnya3UVuZOek/wOWMvCP/v8C/NNHDqF5OAf679bO1qX6Ahxk57NvLyNHPj4DjgBeBP7R+z264n38HtgBvMhKyuZPYzz8wcrr3JrC59XN5U9uopp/GtlG7Hz9BJxXCT9BJhTDsUiEMu1QIwy4VwrBLhTDsUiEMu1QIwy4V4v8AHL9C6fGx+fIAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(np.squeeze(np.rollaxis(x_train[30].cpu().numpy(), 0, 3)), cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5600, 1, 30, 30])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the training batch dimensions (14 files x 400 repeats = 5600 single-channel 30x30 crops).\n",
    "x_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAATBklEQVR4nO3dbYxcV33H8e9/7sx680gCMSTYMQ6qVWRVpElXIWkqWkqhcVrVvKqCyqNAViQiSivUBiFVqvqqLa2qqGncCFKVthAhSloLGRJEC1Tiod60NCQhTpwHyOIEOwnEIYm9u95/X9w769n1Or5r72bsM9+PNJqZe++ZPWfW/vn4nHPvjcxEklSuzrArIElaXQa9JBXOoJekwhn0klQ4g16SCtcddgWWcsEFF+TGjRuHXQ1JOm3cfffdT2Xm2qX2nZJBv3HjRiYnJ4ddDUk6bUTED461z6EbSSqcQS9JhTPoJalwBr0kFc6gl6TCGfSSVDiDXpIKV1TQ3/TVh/j6g/uHXQ1JOqUUFfR///WH+YZBL0kLFBX0a3oVh2YPD7saknRKKSrox7sdDs7MDbsaknRKKSro6x69QS9Jg8oK+m6HgzMO3UjSoLKC3h69JB2lqKAft0cvSUcpKujt0UvS0YoK+vFuh0P26CVpgaKC3h69JB2tqKB3jF6SjlZU0K/pdezRS9IiRQX9eLeyRy9JixQV9PboJeloRQX9eLfi8Fwyc9iwl6S+ooJ+Ta9ujr16STqiqKAf71UAjtNL0oCign5N1x69JC1WVNDbo5ekoxUV9L2qbo6TsZJ0RJlBP5tDrokknTqKCvqxZox+2h69JM0rKuh7VQAO3UjSoKKCfswxekk6SlFB72SsJB2tyKCfdjJWkua1CvqIuCYidkfEnoi4cYn9vxcR9zSPb0bEpW3LrqSxbj1G72SsJB1x3KCPiAq4GdgCbAbeGRGbFx32KPCrmflG4M+AW5dRdsUcWV5p0EtSX5se/RXAnsx8JDOngduBrYMHZOY3M/MnzdtvA+vbll1JjtFL0tHaBP064PGB91PNtmP5APCl5ZaNiG0RMRkRk/v3729RraP119Eb9JJ0RJugjyW2LTnbGRFvoQ76P15u2cy8NTMnMnNi7dq1Lap1tPnJ2MNOxkpSX7fFMVPAxQPv1wN7Fx8UEW8EPglsycynl1N2pbiOXpKO1qZHvwvYFBGXRMQYcB2wY/CAiNgAfAF4d2Y+uJyyK2n+zFgnYyVp3nF79Jk5GxE3AHcCFXBbZt4XEdc3+7cDfwK8Cvi7iACYbYZhliy7Sm2h6gQRLq+UpEFthm7IzJ3AzkXbtg+8/iDwwbZlV0tE0Ks6Br0kDSjqzFiox+m9TLEkHVFe0Hc7TsZK0oDigr5XhUEvSQMKDHrH6CVpUHFBP1Z1mPGEKUmaV1zQ96qO6+glaUB5Qd8Nh24kaUB5QV+56kaSBhUX9GNVh2mHbiRpXnlB7zp6SVqguKDvuepGkhYoMOg9YUqSBhUY9J4wJUmDigt6J2MlaaHigt7llZK0UHlB3w0nYyVpQHFBP1ZVXgJBkgYUF/ReAkGSFiou6Mcco5ekBYoL+l7VYS7h8Jzj9JIEhQY9YK9ekhoFBn0AcMgJWUkCCgz6sW7dpFl79JIEFBj03U5/6MYxekmCEoO+GbpxjF6SasUF/ZiTsZK0QHFB3+/Rz7q8UpKAAoPe5ZWStFCBQd8fo7dHL0lQYND3V924vFKSasUF/ZGhG3v0kgRFBr3LKyVpUHFB32169LNzBr0kQcugj4hrImJ3ROyJiBuX2P+GiPhWRByKiI8u2vdYRHwvIr4bEZMrVfFj6ffop2cdupEkgO7xDoiICrgZeBswBeyKiB2Zef/AYc8AHwbecYyPeUtmPnWylW2jZ49ekhZo06O/AtiTmY9k5jRwO7B18IDM3JeZu4CZVajjsswHvZOxkgS0C/p1wOMD76ea
bW0lcFdE3B0R2451UERsi4jJiJjcv3//Mj5+oW6nGbpxMlaSgHZBH0tsW053+erMvBzYAnwoIt681EGZeWtmTmTmxNq1a5fx8QvZo5ekhdoE/RRw8cD79cDetj8gM/c2z/uAO6iHglaNyyslaaE2Qb8L2BQRl0TEGHAdsKPNh0fEWRFxTv818Hbg3hOtbBtdr3UjSQscd9VNZs5GxA3AnUAF3JaZ90XE9c3+7RFxITAJnAvMRcRHgM3ABcAdEdH/WZ/JzC+vTlNqY/Orbhy6kSRoEfQAmbkT2Llo2/aB109SD+ksdgC49GQquFzzNx7xnrGSBJR4Zmyz6mbGHr0kAQUGfUTQq8IxeklqFBf0UF+q2MsUS1KtzKCvwssUS1KjyKAfqzoO3UhSo8ig71bhmbGS1Cgy6HtVhxmvXilJQMlBb49ekoBCg77bCVfdSFKjyKDvORkrSfMKDXqXV0pSX6FB3/FWgpLUKDLou1Uw483BJQkoNOhdXilJR5Qb9E7GShJQaNDXyysdupEkKDTo7dFL0hGFBr3LKyWpr8ig71Zej16S+ooM+nrVjT16SYJig95bCUpSX5FBX99K0B69JEGhQd/rBtP26CUJKDXovTm4JM0rM+irDnMJh52QlaQyg75bBYATspJEoUHfa4J+1h69JJUa9HWzZmbt0UtSkUHf7Qe9lyqWpDKDfqw/dONaekkqM+i7naZH72SsJBUa9POrbuzRS1KRQT9W2aOXpL5WQR8R10TE7ojYExE3LrH/DRHxrYg4FBEfXU7Z1dCfjHWMXpJaBH1EVMDNwBZgM/DOiNi86LBngA8DnziBsitufujGVTeS1KpHfwWwJzMfycxp4HZg6+ABmbkvM3cBM8stuxrGXEcvSfPaBP064PGB91PNtjZal42IbRExGRGT+/fvb/nxS+t2PDNWkvraBH0ssa1tgrYum5m3ZuZEZk6sXbu25ccvrdd1MlaS+toE/RRw8cD79cDelp9/MmVPWG9+Hb09eklqE/S7gE0RcUlEjAHXATtafv7JlD1h3fkzY+3RS1L3eAdk5mxE3ADcCVTAbZl5X0Rc3+zfHhEXApPAucBcRHwE2JyZB5Yqu1qN6etf1My7TElSi6AHyMydwM5F27YPvH6SelimVdnV1vNaN5I0r8gzY/s9+lnX0UtSmUHfH6OftkcvSWUGfX/VjZOxklRq0LuOXpLmFRn0/TNjXUcvSYUGfc+rV0rSvCKDvuoEnXDoRpKg0KCH+pr0XqZYkgoO+rGq49CNJFFw0HercOhGkig56DsdV91IEgUH/Zg9ekkCCg76btXxzFhJouCg71XBjLcSlKSSg77jzcEliYKDvluFNweXJAoO+l7VcTJWkig56DsGvSRByUHfDc+MlSQKDvquPXpJAgoO+l4VnhkrSRQd9PboJQkKDvpu1XF5pSRRcND3Ol7rRpKg5KB36EaSgIKDvlu5vFKSoOCgt0cvSbWCg97llZIEBQd9verGHr0kFRv09dBNkmmvXtJoKzfoOwHgWnpJI6/coO/WTXPljaRRV2zQd5se/bQrbySNuFZBHxHXRMTuiNgTETcusT8i4qZm/z0RcfnAvsci4nsR8d2ImFzJyr+UXtXv0Rv0kkZb93gHREQF3Ay8DZgCdkXEjsy8f+CwLcCm5vEm4Jbmue8tmfnUitW6hX7Qu8RS0qhr06O/AtiTmY9k5jRwO7B10TFbgU9n7dvAeRFx0QrXdVm6VT1040lTkkZdm6BfBzw+8H6q2db2mATuioi7I2LbiVZ0uXqVq24kCVoM3QCxxLbF6flSx1ydmXsj4tXAVyLigcz8xlE/pP5HYBvAhg0bWlTrpR0ZurFHL2m0tenRTwEXD7xfD+xte0xm9p/3AXdQDwUdJTNvzcyJzJxYu3Ztu9q/hG7HoJckaBf0u4BNEXFJRIwB1wE7Fh2zA3hPs/rmSuDZzHwiIs6KiHMAIuIs4O3AvStY/2Na06yjn5416CWNtuMO3WTm
bETcANwJVMBtmXlfRFzf7N8O7ASuBfYALwDvb4q/BrgjIvo/6zOZ+eUVb8US1vTqoD84Y9BLGm1txujJzJ3UYT64bfvA6wQ+tES5R4BLT7KOJ2S8VwFwcPbwMH68JJ0yij0ztj90c8gevaQRV2zQ93v0h+zRSxpxxQf9wRmDXtJoKzfou07GShIUHPRrHLqRJKDgoLdHL0m1YoO+W3XodsIxekkjr9igh3qJpT16SaOu6KAf71WO0UsaecUHvT16SaOu6KBf0+t4CQRJI6/soO9WHHIyVtKIKzrox3sdDnmZYkkjruyg71Yur5Q08soO+p7LKyWp6KA/c02X56dnh10NSRqqooP+3PEuzx006CWNtqKD/uw1XZ47ODPsakjSUBUd9OeM9zg4M8fMYcfpJY2uwoO+viXuzxy+kTTCig76s9c0QX/IoJc0uooO+nPGewAccJxe0ggrPOgdupGkkQh6l1hKGmWFB309dOMYvaRRVnTQn9v06J990TF6SaOr6KA//8wxelXw5IGDw66KJA1N0UHf6QQXvmKcvT99cdhVkaShKTroAV77ijN44qf26CWNrvKD/rwz+JE9ekkjrPigv+gV4/z4wEFmvd6NpBFVfNC/fu3ZzM4ljz39wrCrIklDUXzQv+HCcwB44MkDQ66JJA1H8UH/c68+m6oTPPDEc8OuiiQNRaugj4hrImJ3ROyJiBuX2B8RcVOz/56IuLxt2dU23qv4hdeey533PcncXL7cP16Shu64QR8RFXAzsAXYDLwzIjYvOmwLsKl5bANuWUbZVfe+qzfy0L6f8Zd37eaxp57nwMEZfvj0Czzx7Ivs2fccUz95gUOzh1/uaknSy6Lb4pgrgD2Z+QhARNwObAXuHzhmK/DpzEzg2xFxXkRcBGxsUXbVbb10Hf/14FPc8rWHueVrDx/zuDN6FWPdDr2qQ9WBIOgERAQR0Gmeg3pbycpuHUU3sOCmAWX/3XvlmWN87vqrVvxz2wT9OuDxgfdTwJtaHLOuZVkAImIb9f8G2LBhQ4tqtdfpBH/1u5fy7qtex8P7n+eZ5w9x3hljzMzNcfaaLs8fOszTPzvEsy/OMH14jpnDSWYyl0kmzCUk/df1c8kKbx5Z8C+w3JY1Cm9g/4q7K63Npy71z+fir/tYx7QpW2/MvBW4FWBiYmLFf50RwWUbzueyDeev9EdL0imtTdBPARcPvF8P7G15zFiLspKkVdRm1c0uYFNEXBIRY8B1wI5Fx+wA3tOsvrkSeDYzn2hZVpK0io7bo8/M2Yi4AbgTqIDbMvO+iLi+2b8d2AlcC+wBXgDe/1JlV6UlkqQlxak4MTUxMZGTk5PDroYknTYi4u7MnFhqX/FnxkrSqDPoJalwBr0kFc6gl6TCnZKTsRGxH/jBCRa/AHhqBatzOrDNo8E2l+9k2vu6zFy71I5TMuhPRkRMHmvmuVS2eTTY5vKtVnsdupGkwhn0klS4EoP+1mFXYAhs82iwzeVblfYWN0YvSVqoxB69JGmAQS9JhSsm6Id9E/LVEhEXR8R/RsT3I+K+iPj9ZvsrI+IrEfFQ83z+QJmPNd/D7oj4zeHV/uRERBUR/xsRX2zeF93m5hacn4+IB5rf91Uj0OY/aP5c3xsRn42I8dLaHBG3RcS+iLh3YNuy2xgRvxQR32v23RTLuadiZp72D+pLID8MvJ76Zif/B2wedr1WqG0XAZc3r88BHqS+0fpfADc2228E/rx5vblp/xrgkuZ7qYbdjhNs+x8CnwG+2Lwvus3APwIfbF6PAeeV3GbqW40+CpzRvP8c8L7S2gy8GbgcuHdg27LbCPw3cBX1nfu+BGxpW4dSevTzNzDPzGmgfxPy015mPpGZ/9O8fg74PvVfkK3UwUDz/I7m9Vbg9sw8lJmPUt8j4IqXt9YnLyLWA78FfHJgc7FtjohzqQPhUwCZOZ2ZP6XgNje6wBkR0QXOpL4DXVFtzsxvAM8s2rysNkbERcC5mfmtrFP/0wNl
jquUoD/WzcmLEhEbgcuA7wCvyfouXjTPr24OK+W7+Bvgj4C5gW0lt/n1wH7gH5rhqk9GxFkU3ObM/BHwCeCHwBPUd6a7i4LbPGC5bVzXvF68vZVSgr71TchPVxFxNvCvwEcy88BLHbrEttPqu4iI3wb2ZebdbYssse20ajN1z/Zy4JbMvAx4nvq/9Mdy2re5GZfeSj1E8VrgrIh410sVWWLbadXmFo7VxpNqeylB3+YG5qetiOhRh/y/ZOYXms0/bv47R/O8r9lewndxNfA7EfEY9TDcr0fEP1N2m6eAqcz8TvP+89TBX3KbfwN4NDP3Z+YM8AXglym7zX3LbeNU83rx9lZKCfpib0LezKx/Cvh+Zv71wK4dwHub1+8F/n1g+3URsSYiLgE2UU/inDYy82OZuT4zN1L/Lv8jM99F2W1+Eng8In6+2fRW4H4KbjP1kM2VEXFm8+f8rdRzUCW3uW9ZbWyGd56LiCub7+o9A2WOb9gz0is4s30t9YqUh4GPD7s+K9iuX6H+L9o9wHebx7XAq4CvAg81z68cKPPx5nvYzTJm5k/FB/BrHFl1U3SbgV8EJpvf9b8B549Am/8UeAC4F/gn6tUmRbUZ+Cz1HMQMdc/8AyfSRmCi+Z4eBv6W5soGbR5eAkGSClfK0I0k6RgMekkqnEEvSYUz6CWpcAa9JBXOoJekwhn0klS4/wcT5vBhBa9EqwAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(losses_te)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the trained network. torch.save(model, ...) pickles the whole module object;\n",
    "# the commented-out torch.load cell above reads the file back in the same form.\n",
    "# NOTE(review): PyTorch recommends saving model.state_dict() instead -- switch both cells together if desired.\n",
    "os.makedirs('model', exist_ok=True)  # torch.save fails when the target directory is missing\n",
    "torch.save(model, 'model/v3')"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "79622e2cdd399c04171d1ae422da3eb99546339dc74b7429cf74d873e1d44f60"
  },
  "kernelspec": {
   "display_name": "Python 3.8.2 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
